{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Decision Tree Classifier\n",
    "\n",
    "Trees are a popular class of algorithms in machine learning. In this notebook, we build a simple Decision Tree Classifier with `scikit-learn` and show that it can be executed homomorphically using Concrete.\n",
    "\n",
    "Converting a tree working over quantized data to its FHE equivalent takes only a few lines of code thanks to Concrete ML.\n",
    "\n",
    "Let's dive in!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The use case\n",
    "\n",
    "The use case is a spam classification task from OpenML, which you can find here: https://www.openml.org/d/44\n",
    "\n",
    "Pre-extracted features (such as word frequencies) are provided, along with a class label - `0` for a legitimate e-mail and `1` for spam - for 4601 samples.\n",
    "\n",
    "Let's first fetch the data-set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numpy\n",
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "features, classes = fetch_openml(data_id=44, as_frame=False, cache=True, return_X_y=True)\n",
    "classes = classes.astype(numpy.int64)\n",
    "\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    features,\n",
    "    classes,\n",
    "    test_size=0.15,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's use scikit-learn's cross-validation tools to find the best hyper-parameters for our model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best hyper parameters: {'max_depth': None, 'max_features': None, 'min_samples_leaf': 10, 'min_samples_split': 100}\n",
      "Best score: 0.9295527168344\n"
     ]
    }
   ],
   "source": [
    "# Find best hyper parameters with cross validation\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "from concrete.ml.sklearn import DecisionTreeClassifier as ConcreteDecisionTreeClassifier\n",
    "\n",
    "# List of hyper parameters to tune\n",
    "param_grid = {\n",
    "    \"max_features\": [None, \"sqrt\", \"log2\"],\n",
    "    \"min_samples_leaf\": [1, 10, 100],\n",
    "    \"min_samples_split\": [2, 10, 100],\n",
    "    \"max_depth\": [None, 2, 4, 6, 8],\n",
    "}\n",
    "\n",
    "grid_search = GridSearchCV(\n",
    "    ConcreteDecisionTreeClassifier(),\n",
    "    param_grid,\n",
    "    cv=10,\n",
    "    scoring=\"average_precision\",\n",
    "    error_score=\"raise\",\n",
    "    n_jobs=1,\n",
    ")\n",
    "\n",
    "gs_results = grid_search.fit(x_train, y_train)\n",
    "print(\"Best hyper parameters:\", gs_results.best_params_)\n",
    "print(\"Best score:\", gs_results.best_score_)\n",
    "\n",
    "# Build the model with best hyper parameters\n",
    "model = ConcreteDecisionTreeClassifier(\n",
    "    max_features=gs_results.best_params_[\"max_features\"],\n",
    "    min_samples_leaf=gs_results.best_params_[\"min_samples_leaf\"],\n",
    "    min_samples_split=gs_results.best_params_[\"min_samples_split\"],\n",
    "    max_depth=gs_results.best_params_[\"max_depth\"],\n",
    "    n_bits=6,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's train the model and compute some metrics on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit both the Concrete ML model and an equivalent scikit-learn model\n",
    "model, sklearn_model = model.fit_benchmark(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sklearn average precision score: 0.94\n",
      "Concrete average precision score: 0.97\n"
     ]
    }
   ],
   "source": [
    "# Compute average precision on test\n",
    "from sklearn.metrics import average_precision_score\n",
    "\n",
    "# pylint: disable=no-member\n",
    "y_pred_concrete = model.predict_proba(x_test)[:, 1]\n",
    "y_pred_sklearn = sklearn_model.predict_proba(x_test)[:, 1]\n",
    "concrete_average_precision = average_precision_score(y_test, y_pred_concrete)\n",
    "sklearn_average_precision = average_precision_score(y_test, y_pred_sklearn)\n",
    "print(f\"Sklearn average precision score: {sklearn_average_precision:0.2f}\")\n",
    "print(f\"Concrete average precision score: {concrete_average_precision:0.2f}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the Concrete average precision score is not computed in FHE here, as that would take much longer. To run the model in FHE, set the `fhe` argument to `\"execute\"` in `predict_proba()`. Also, the average precision of the Concrete model may be higher here, likely because quantization acts as a form of regularization that improves the test-set metric. In general, however, quantization should be expected to decrease the average precision."
   ]
  },
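  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch (assuming the model has been compiled with `model.compile`, which is done a few cells below), running a prediction fully in FHE only requires changing the `fhe` argument:\n",
    "\n",
    "```python\n",
    "# Sketch only: the model must be compiled before FHE execution\n",
    "y_proba_fhe = model.predict_proba(x_test[:1], fhe=\"execute\")\n",
    "```"
   ]
  },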
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of test samples: 691\n",
      "Number of spams in test samples: 304\n",
      "True Negative (legit mail well classified) rate: 0.9612403100775194\n",
      "False Positive (legit mail classified as spam) rate: 0.03875968992248062\n",
      "False Negative (spam mail classified as legit) rate: 0.14473684210526316\n",
      "True Positive (spam well classified) rate: 0.8552631578947368\n"
     ]
    }
   ],
   "source": [
    "# Show the confusion matrix on x_test\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "y_pred = model.predict(x_test)\n",
    "true_negative, false_positive, false_negative, true_positive = confusion_matrix(\n",
    "    y_test, y_pred, normalize=\"true\"\n",
    ").ravel()\n",
    "\n",
    "num_samples = len(y_test)\n",
    "num_spam = sum(y_test)\n",
    "\n",
    "print(f\"Number of test samples: {num_samples}\")\n",
    "print(f\"Number of spams in test samples: {num_spam}\")\n",
    "\n",
    "print(f\"True Negative (legit mail well classified) rate: {true_negative}\")\n",
    "print(f\"False Positive (legit mail classified as spam) rate: {false_positive}\")\n",
    "print(f\"False Negative (spam mail classified as legit) rate: {false_negative}\")\n",
    "print(f\"True Positive (spam well classified) rate: {true_positive}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now we are ready to move into the FHE domain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from concrete.compiler import check_gpu_available\n",
    "\n",
    "use_gpu_if_available = False\n",
    "device = \"cuda\" if use_gpu_if_available and check_gpu_available() else \"cpu\"\n",
    "\n",
    "# We first compile the model with some data, here the training set\n",
    "circuit = model.compile(x_train, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate the key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating a key for an 8-bit circuit\n"
     ]
    }
   ],
   "source": [
    "print(f\"Generating a key for an {circuit.graph.maximum_integer_bit_width()}-bit circuit\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Key generation time: 0.78 seconds\n"
     ]
    }
   ],
   "source": [
    "time_begin = time.time()\n",
    "circuit.client.keygen(force=False)\n",
    "print(f\"Key generation time: {time.time() - time_begin:.2f} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the sample size for a faster total execution time\n",
    "FHE_SAMPLES = 10\n",
    "x_test = x_test[:FHE_SAMPLES]\n",
    "y_pred = y_pred[:FHE_SAMPLES]\n",
    "y_reference = y_test[:FHE_SAMPLES]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Execution time: 0.52 seconds per sample\n"
     ]
    }
   ],
   "source": [
    "# Predict in FHE for a few examples\n",
    "time_begin = time.time()\n",
    "y_pred_fhe = model.predict(x_test, fhe=\"execute\")\n",
    "print(f\"Execution time: {(time.time() - time_begin) / len(x_test):.2f} seconds per sample\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ground truth:       [0 0 0 1 0 1 0 0 0 0]\n",
      "Prediction sklearn: [0 0 0 1 0 1 0 0 0 0]\n",
      "Prediction FHE:     [0 0 0 1 0 1 0 0 0 0]\n"
     ]
    }
   ],
   "source": [
    "# Check prediction FHE vs sklearn\n",
    "print(f\"Ground truth:       {y_reference}\")\n",
    "print(f\"Prediction sklearn: {y_pred}\")\n",
    "print(f\"Prediction FHE:     {y_pred_fhe}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10/10 predictions are similar between the FHE model and the clear sklearn model.\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    f\"{numpy.sum(y_pred_fhe == y_pred)}/{FHE_SAMPLES} predictions are similar \"\n",
    "    \"between the FHE model and the clear sklearn model.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, FHE inference is run on 10 samples to check that it gives the same results as the clear model. Running FHE inference over the full data-set (to get the true FHE precision score) would be too expensive."
   ]
  },
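  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a cheaper alternative (a sketch, assuming the compiled model from above), FHE simulation with `fhe=\"simulate\"` executes the quantized circuit in the clear, which makes it fast enough to estimate the FHE metrics over many samples:\n",
    "\n",
    "```python\n",
    "# Sketch only: simulation runs in the clear and is much faster than FHE\n",
    "y_pred_simulated = model.predict(x_test, fhe=\"simulate\")\n",
    "```"
   ]
  },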
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Fully Homomorphic Decision Trees are now within reach for any data scientist familiar with the scikit-learn API."
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
